#======================================
#========Monkey Category Chains========
#======================================

#====================
#====Front Matter====
#====================

library(rstan)
library(rethinking)

#========================
#====Custom Functions====
#========================

# Probability-weighted mean progress through a process-of-elimination list.
#
# p: per-position success probabilities. Either a plain numeric vector
#    (treated as a single row) or a matrix-like object with one row per
#    case and one column per list position.
#
# Returns a numeric vector with one value per row of p:
#    sum_{i=1}^{P-1} i * (1 - p[i+1]) * prod_{j<=i} p[j]  +  P * prod_{j<=P} p[j]
# i.e. each possible progress count i weighted by the probability of getting
# exactly the first i positions right, where P is the number of columns.
simchain_density <- function(p){
	# A bare vector has no dim attribute; promote it to a 1-row matrix.
	if (is.null(dim(p))) {
		p <- matrix(p,nrow=1,ncol=length(p))
	}
	n_case <- nrow(p)
	n_pos <- ncol(p)
	# With no positions there is nothing to get right.
	if (n_pos < 1) {
		return(rep(0,n_case))
	}
	d <- rep(0,n_case)
	reach <- rep(1,n_case)  # running product p[,1] * ... * p[,i]
	# seq_len() yields an empty loop for a single position, where 1:(n_pos-1)
	# would count down through c(1, 0) and index a non-existent column.
	for (i in seq_len(n_pos-1)) {
		reach <- reach*p[,i]
		d <- d + i*reach*(1-p[,i+1])
	}
	# Completing the whole list: all n_pos positions correct.
	# Returned as a plain expression so the value is visible to the caller.
	d + n_pos*reach*p[,n_pos]
}

#===================
#====Stan Models====
#===================

# Stan program for the accuracy ("progress") data, kept as a string and handed
# to stan(model_code = ...) further down in this script.
#
# functions:
#   dsim_log(y, N, P, prob) sums, over observations i, log(prob[i, j]) for
#   j = 1..y[i] and, when y[i] < P, adds log(1 - prob[i, y[i] + 1]).
# data:
#   N observations, S subjects, P list positions, and per-observation
#   Slist (subject index), trial, and progress. `trial` is declared but is
#   not read anywhere else in this program.
# parameters / transformed parameters:
#   f[s, j] = mu_f[j] + sigma_f[j] * f_vec[s + (j - 1) * S], so f_vec stores
#   the S subject offsets for position 1 first, then position 2, and so on.
# model:
#   mu_f ~ normal(0, 5); sigma_f ~ cauchy(0, 2) with sigma_f declared lower=0;
#   f_vec ~ normal(0, 1). p[i, j] = inv_logit(f[Slist[i], j]) is passed to
#   dsim_log and the result is added to the log density. (The in-string
#   comment on that line says "log-odds"; the value added is the
#   log-likelihood returned by dsim_log.)
#
# NOTE(review): the data block only enforces lower=0, but the indexing needs
#   Slist in 1..S and progress in 0..P -- confirm against the input data.
# NOTE(review): `<-`, increment_log_prob(), the `_log` suffix and `int[] y`
#   look like legacy Stan syntax -- confirm the installed rstan still accepts
#   them before upgrading.
category_chain_model_string <- "
functions {
	real dsim_log(int[] y, int N, int P, real[,] prob) {
		// Computes the log-likelihood for a process-of-elimination serial task (such as SimChain)
		vector[N] q;
		for (i in 1:N) {
			q[i] <- 0;
			for (j in 1:y[i]) {
				q[i] <- q[i] + log(prob[i,j]);        // Each correct responses in a sequence 1:j during trial i corresponds to success with a probability of p[i,j]
			}
			if (y[i] < P) {
				q[i] <- q[i] + log(1-prob[i,y[i]+1]); // If an error is made, the log-likelihood must be incremented by a failure with a probability of (1-p[i,y[i]+1])
			}
		}
		return sum(q);
	}
}
data{
	int<lower=0> N;              // Number of observations
	int<lower=0> S;              // Number of subjects
	int<lower=0> P;              // Number of list stimuli
	int<lower=0> Slist[N];       // List of subject indices
	int<lower=0> trial[N];       // List of trial numbers
	int<lower=0> progress[N];    // List of trial progressions
}
parameters{
	real mu_f[P];             // Population central tendency for f
	real<lower=0> sigma_f[P]; // Population dispersion for f
	real f_vec[S*P];          // z-scores used to discover subject-level f parameters
}
transformed parameters {
	real f[S,P];              // Subject-level log-odds intercept parameters
	for ( s in 1:S ) {
		for ( j in 1:P ) {
			f[s,j] <- mu_f[j] +  sigma_f[j]*f_vec[s+(j-1)*S]; // Conversion of z-scores to real parameters values
		}
	}
}
model{
	mu_f ~ normal(0, 5);
	sigma_f ~ cauchy(0, 2);
	f_vec ~ normal ( 0 , 1 );  // Implies f[s,j] ~ normal ( mu_f[j] , sigma_f[j] )
	{
		real p[N,P];
		for ( i in 1:N ) {
			for ( j in 1:P ) {
				p[i,j] <- inv_logit(f[Slist[i],j]);   // Convert intercept parameters to probability
			}
		}
		increment_log_prob(dsim_log(progress,N,P,p)); // Compute log-odds given list of trial successes
	}
}
"

# Stan program for the reaction-time data, kept as a string and handed to
# stan(model_code = ...) further down in this script.
#
# data:
#   N observations, S subjects, P list positions, and per-observation
#   Slist (subject index), position (response position) and react
#   (log reaction time).
# parameters / transformed parameters:
#   m_rt[s, j] = mu_m[j] + sigma_m[j] * m_vec[s + (j - 1) * S]
#   s_rt[s, j] = mu_s[j] + sigma_s[j] * s_vec[s + (j - 1) * S]
#   so m_vec holds the offsets for m_rt and s_vec the offsets for s_rt (the
#   in-string comments on those two declarations name "f" and "m" instead).
# model:
#   mu_m, mu_s ~ normal(0, 20); sigma_m, sigma_s ~ cauchy(0, 2) with both
#   declared lower=0; m_vec, s_vec ~ normal(0, 1). Each observation adds
#   normal_log(react[i], m_rt[Slist[i], position[i]], s_rt[Slist[i], position[i]])
#   to the log density.
#
# NOTE(review): s_rt is declared lower=0 but is built from an unconstrained
#   linear expression, so it can come out negative -- presumably such draws
#   are rejected by Stan; confirm this is intended.
# NOTE(review): the data block only enforces lower=0, but the indexing needs
#   Slist in 1..S and position in 1..P -- confirm against the input data.
# NOTE(review): `<-`, increment_log_prob() and normal_log() look like legacy
#   Stan syntax -- confirm the installed rstan still accepts them.
category_chain_RT_string <- "
data{
	int<lower=0> N;              // Number of observations
	int<lower=0> S;              // Number of subjects
	int<lower=0> P;              // Number of list stimuli
	int<lower=0> Slist[N];       // List of subject indices
	int<lower=0> position[N];    // List of response positions
	real react[N];               // List of log reaction times
}
parameters{
	real mu_m[P];             // Population central tendency for m_rt
	real mu_s[P];             // Population central tendency for s_rt
	real<lower=0> sigma_m[P]; // Population dispersion for m_rt
	real<lower=0> sigma_s[P]; // Population dispersion for s_rt
	real m_vec[S*P];          // z-scores used to discover subject-level f parameters
	real s_vec[S*P];          // z-scores used to discover subject-level m parameters
}
transformed parameters {
	real m_rt[S,P];           // Subject-level mean log reaction times
	real<lower=0> s_rt[S,P];  // Subject-level st. dev. log reaction times
	for ( s in 1:S ) {
		for ( j in 1:P ) {
			m_rt[s,j] <- mu_m[j] +  sigma_m[j]*m_vec[s+(j-1)*S]; // Conversion of z-scores to real m parameters values
			s_rt[s,j] <- mu_s[j] +  sigma_s[j]*s_vec[s+(j-1)*S]; // Conversion of z-scores to real s parameters values
		}
	}	
}
model{
	mu_m ~ normal(0, 20);
	mu_s ~ normal(0, 20);
	sigma_m ~ cauchy(0, 2);
	sigma_s ~ cauchy(0, 2);
	m_vec ~ normal(0, 1);  // Implies m_rt[s,j] ~ normal ( mu_m[j] , sigma_m[j] )
	s_vec ~ normal(0, 1);  // Implies s_rt[s,j] ~ normal ( mu_s[j] , sigma_s[j] )
	for ( i in 1:N ) {
		increment_log_prob(normal_log(react[i], m_rt[Slist[i],position[i]], s_rt[Slist[i],position[i]]));
	}
}
"

#========================================
#====Set Working Directory, Load Data====
#========================================

# Read the trial-level records; the file is looked up in the current
# working directory.
MCC <- read.csv("MonkeyCategoryChain.csv")

# Input for the accuracy model: one entry per row of MCC.
# NOTE(review): the Stan programs index subject-level arrays by Slist, so
# Subject presumably already holds integers 1..S -- confirm against the CSV.
data <- list(
	N = length(MCC$Subject),
	S = length(unique(MCC$Subject)),
	P = 4,
	T = max(MCC$Trial),
	Slist = MCC$Subject,
	trial = MCC$Trial,
	progress = MCC$Progress
)

# Input for the reaction-time model: the four React columns are stacked into
# one long vector (position 1 first, then 2, 3, 4), entries with a missing
# reaction time are dropped, and the remaining times are log-transformed.
rt_data <- local({
	react_long <- c(MCC$React1, MCC$React2, MCC$React3, MCC$React4)
	observed <- !is.na(react_long)
	subject_long <- rep(MCC$Subject, 4)[observed]
	list(
		N = length(subject_long),
		S = length(unique(MCC$Subject)),
		P = 4,
		T = max(MCC$Trial),
		Slist = subject_long,
		position = rep(c(1, 2, 3, 4), each = length(MCC$Subject))[observed],
		react = log(react_long[observed])
	)
})

#==================
#====Deploy Stan===
#==================

# Starting values for the accuracy model, one inner list per chain (a single
# chain here). Despite the name, this is what gets passed as stan(init = ...):
# population locations and subject offsets start at 0, dispersions at 1.
c_prior <- list(with(data, list(
	mu_f = rep(0, P),
	sigma_f = rep(1, P),
	f_vec = rep(0, S * P)
)))

# Starting values for the reaction-time model, one inner list per chain (a
# single chain here); passed as stan(init = ...). Sizes are taken from `data`,
# which carries the same S and P as rt_data.
# With mu_s = 0, sigma_s = 1 and s_vec = 1 the initial subject-level standard
# deviations mu_s + sigma_s * s_vec equal 1, i.e. they start out positive.
c_rt_prior <- list(with(data, list(
	m_vec = rep(0, S * P),
	s_vec = rep(1, S * P),
	mu_m = rep(0, P),
	mu_s = rep(0, P),
	sigma_m = rep(1, P),
	sigma_s = rep(1, P)
)))

#==Warning: sampling below is expected to take 4+ hours and to temporarily use a considerable amount of memory==
samp_w <- 1000  # warmup iterations
samp_m <- 4000  # post-warmup iterations

# Accuracy model: fit, pull out the draws, then trace plot and summary of the
# population-level parameters.
cat_chain_sim <- stan(
	model_code = category_chain_model_string,
	data = data,
	init = c_prior,
	chains = 1,
	warmup = samp_w,
	iter = samp_m + samp_w
)
c_prms <- extract(cat_chain_sim, permuted = TRUE)
stan_trace(cat_chain_sim, pars = c("mu_f", "sigma_f"))
print(cat_chain_sim, pars = c("mu_f", "sigma_f"))

# Reaction-time model: same sequence, reporting the population-level
# locations of the mean and standard deviation of log reaction time.
cat_chain_rt_sim <- stan(
	model_code = category_chain_RT_string,
	data = rt_data,
	init = c_rt_prior,
	chains = 1,
	warmup = samp_w,
	iter = samp_m + samp_w
)
c_rt_prms <- extract(cat_chain_rt_sim, permuted = TRUE)
stan_trace(cat_chain_rt_sim, pars = c("mu_m", "mu_s"))
print(cat_chain_rt_sim, pars = c("mu_m", "mu_s"))
